{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load libraries\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "from res.plot_lib import set_default, show_scatterplot, plot_bases\n",
    "from matplotlib.pyplot import plot, title, axis"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Set style (needs to be in a new cell)\n",
    "# set_default comes from res.plot_lib; presumably it configures the plotting defaults -- verify there\n",
    "set_default()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Run on the first CUDA GPU when one is available, otherwise on the CPU\n",
    "if torch.cuda.is_available():\n",
    "    device = torch.device(\"cuda:0\")\n",
    "else:\n",
    "    device = torch.device(\"cpu\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Sample n_points standard-normal points in 2-D and move them to `device`\n",
    "n_points = 1000\n",
    "X = torch.randn(n_points, 2).to(device)\n",
    "# Colour every point by its first coordinate\n",
    "colors = X[:, 0]\n",
    "show_scatterplot(X, colors, title='X')\n",
    "\n",
    "# OI stacks two all-zero rows on top of the 2x2 identity; it is what plot_bases receives\n",
    "OI = torch.cat([torch.zeros(2, 2), torch.eye(2)], dim=0).to(device)\n",
    "plot_bases(OI)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Visualizing Linear Transformations\n",
    "\n",
    "* Generate a random matrix $W$\n",
    "\n",
    "$$\n",
    "    W = U\n",
    "  \\left[ {\\begin{array}{cc}\n",
    "   s_1 & 0 \\\\\n",
    "   0 & s_2 \\\\\n",
    "  \\end{array} } \\right]\n",
    "  V^\\top\n",
    "$$\n",
    "* Compute $y = Wx$\n",
    "* Larger singular values stretch the points\n",
    "* Smaller singular values push them together\n",
    "* $U, V$ rotate/reflect"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Show the original cloud, then ten random linear maps y = Wx applied to it\n",
    "show_scatterplot(X, colors, title='X')\n",
    "plot_bases(OI)\n",
    "\n",
    "for i in range(10):\n",
    "    # create a random matrix\n",
    "    W = torch.randn(2, 2).to(device)\n",
    "    # transform points: each row of X is one point, so y = Wx is written X @ W^T\n",
    "    Y = X @ W.t()\n",
    "    # compute singular values\n",
    "    U, S, V = torch.svd(W)\n",
    "    # plot transformed points\n",
    "    show_scatterplot(Y, colors, title='y = Wx, singular values : [{:.3f}, {:.3f}]'.format(S[0], S[1]))\n",
    "    # transform the basis with the same map as the points (W^T on the right, not W)\n",
    "    new_OI = OI @ W.t()\n",
    "    # plot old and new basis\n",
    "    plot_bases(OI)\n",
    "    plot_bases(new_OI)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Linear transformation with PyTorch"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# A single bias-free 2 -> 2 linear layer, left at its default initialisation\n",
    "model = nn.Sequential(nn.Linear(2, 2, bias=False)).to(device)\n",
    "\n",
    "# Forward passes are for plotting only, so gradient tracking is switched off\n",
    "with torch.no_grad():\n",
    "    Y = model(X)\n",
    "    show_scatterplot(Y, colors)\n",
    "    OI_mapped = model(OI)\n",
    "    plot_bases(OI_mapped)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Non-linear Transform: Map Points to a Square\n",
    "\n",
    "* Linear transforms can rotate, reflect, stretch and compress, but cannot curve\n",
    "* We need non-linearities for this\n",
    "* Can (approximately) map points to a square by first stretching out by a factor $s$, then squashing with a tanh function\n",
    "\n",
    "$\n",
    "   f(x)= \\tanh \\left(\n",
    "  \\left[ {\\begin{array}{cc}\n",
    "   s & 0 \\\\\n",
    "   0 & s \\\\\n",
    "  \\end{array} } \\right]  \n",
    "  x\n",
    "  \\right)\n",
    "$"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Evaluate tanh at 101 evenly spaced points in [-10, 10] and draw the curve\n",
    "z = torch.linspace(-10, 10, 101)\n",
    "s = z.tanh()\n",
    "plot(z.numpy(), s.numpy())\n",
    "title('tanh() non linearity');"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Reference plot of the untransformed points\n",
    "show_scatterplot(X, colors, title='X')\n",
    "plot_bases(OI)\n",
    "\n",
    "# f(x) = tanh(W x): a bias-free linear layer followed by tanh\n",
    "model = nn.Sequential(\n",
    "        nn.Linear(2, 2, bias=False),\n",
    "        nn.Tanh()\n",
    ")\n",
    "\n",
    "model.to(device)\n",
    "\n",
    "for s in range(1, 6):\n",
    "    # W = s * I scales both coordinates by s before the tanh\n",
    "    W = s * torch.eye(2)\n",
    "    # no_grad replaces the legacy `.data` access: the in-place weight write and the\n",
    "    # forward pass are both kept out of autograd, as in the other cells\n",
    "    with torch.no_grad():\n",
    "        model[0].weight.copy_(W)\n",
    "        Y = model(X)\n",
    "    show_scatterplot(Y, colors, title=f'f(x), s={s}')\n",
    "    plot_bases(OI, width=0.01)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Visualize Functions Represented by Random Neural Networks"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "show_scatterplot(X, colors, title='x')\n",
    "n_hidden = 5\n",
    "\n",
    "# Non-linearity used by every sampled network (swap in nn.ReLU() to compare)\n",
    "NL = nn.Tanh()\n",
    "\n",
    "for trial in range(5):\n",
    "    # freshly initialised network: 2 -> n_hidden -> 2 with one non-linearity in between\n",
    "    layers = [nn.Linear(2, n_hidden), NL, nn.Linear(n_hidden, 2)]\n",
    "    model = nn.Sequential(*layers).to(device)\n",
    "    # forward pass only; no gradients are needed for plotting\n",
    "    with torch.no_grad():\n",
    "        Y = model(X)\n",
    "    show_scatterplot(Y, colors, title='f(x)')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# deeper network with random weights\n",
    "show_scatterplot(X, colors, title='x')\n",
    "n_hidden = 5\n",
    "\n",
    "# Non-linearity shared by all layers (swap in nn.ReLU() to compare)\n",
    "NL = nn.Tanh()\n",
    "\n",
    "for trial in range(5):\n",
    "    # input layer, three hidden-to-hidden layers (each followed by NL), then the 2-D output layer;\n",
    "    # layers are constructed in the same order as they appear in the network\n",
    "    layers = [nn.Linear(2, n_hidden), NL]\n",
    "    for _depth in range(3):\n",
    "        layers += [nn.Linear(n_hidden, n_hidden), NL]\n",
    "    layers.append(nn.Linear(n_hidden, 2))\n",
    "    model = nn.Sequential(*layers).to(device)\n",
    "    # forward pass only; under no_grad the output carries no autograd history\n",
    "    with torch.no_grad():\n",
    "        Y = model(X)\n",
    "    show_scatterplot(Y, colors, title='f(x)')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python [conda env:dl-minicourse] *",
   "language": "python",
   "name": "conda-env-dl-minicourse-py"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
